{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we begin, we will change a few settings to make the notebook look a bit prettier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "%%html\n",
    "<style> body {font-family: \"Calibri\", cursive, sans-serif;} </style>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# 02 - Treatment Recommendation\n",
    "One of the original DeepSurv's coolest features is that it can be used as a \n",
    "personalized treatment recommender. In this notebook, we will see how this \n",
    "works in DeepSurvK.\n",
    "\n",
    "This notebook assumes that you have gone through the [basics of DeepSurv](./00_understanding_deepsurv.ipynb)\n",
    "as well as [DeepSurvK's basic usage](./01_deepsurvk_quickstart.ipynb).\n",
    "\n",
    "## Preliminaries\n",
    "\n",
    "Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "import deepsurvk\n",
    "from deepsurvk.datasets import load_rgbsg"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fit model\n",
    "The first step is to generate and fit a DeepSurvK model.\n",
    "We will do so in the same manner as we did before.\n",
    "\n",
    "### Get data\n",
    "We will use the RGBSG dataset, since this is the one that was used as an\n",
    "example in the original paper (p. 8)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, Y_train, E_train = load_rgbsg(partition='training')\n",
    "X_test, Y_test, E_test = load_rgbsg(partition='testing')\n",
    "\n",
    "# Calculate important parameters.\n",
    "n_patients_train = X_train.shape[0]\n",
    "n_features = X_train.shape[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pre-process data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# Standardization\n",
    "cols_standardize = ['grade', 'age', 'n_positive_nodes', 'progesterone', 'estrogen']\n",
    "X_ct = ColumnTransformer([('standardizer', StandardScaler(), cols_standardize)])\n",
    "X_ct.fit(X_train[cols_standardize])\n",
    "\n",
    "X_train[cols_standardize] = X_ct.transform(X_train[cols_standardize])\n",
    "X_test[cols_standardize] = X_ct.transform(X_test[cols_standardize])\n",
    "\n",
    "Y_scaler = StandardScaler().fit(Y_train)\n",
    "Y_train['T'] = Y_scaler.transform(Y_train)\n",
    "Y_test['T'] = Y_scaler.transform(Y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
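    "# Sort patients by descending survival time.\n",
    "# `.iloc` is used because `np.argsort` returns positions, not index labels.\n",
    "sort_idx = np.argsort(Y_train.to_numpy(), axis=None)[::-1]\n",
    "X_train = X_train.iloc[sort_idx, :]\n",
    "Y_train = Y_train.iloc[sort_idx, :]\n",
    "E_train = E_train.iloc[sort_idx, :]"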
   ]
  },
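  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why sort at all? With patients ordered by descending survival time, the risk set\n",
    "of the patient in row $k$ (everyone still at risk at that patient's time) is simply\n",
    "rows $0$ to $k$, so the sum over the risk set in a Cox-type loss reduces to a\n",
    "cumulative sum. The cell below is a minimal NumPy sketch of that idea, assuming no\n",
    "tied times. It is not DeepSurvK's actual loss: the helper name\n",
    "`neg_log_partial_likelihood`, the averaging over observed events, and the toy\n",
    "values are assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def neg_log_partial_likelihood(log_risk, event):\n",
    "    # Hypothetical helper: Cox negative log partial likelihood.\n",
    "    # Assumes rows are sorted by descending survival time, without ties.\n",
    "    log_risk = np.asarray(log_risk, dtype=float)\n",
    "    event = np.asarray(event, dtype=float)\n",
    "    # The risk set of row k is rows 0..k, so a running sum covers it.\n",
    "    log_risk_set = np.log(np.cumsum(np.exp(log_risk)))\n",
    "    # Only patients with an observed event contribute to the sum.\n",
    "    return -np.sum((log_risk - log_risk_set) * event) / np.sum(event)\n",
    "\n",
    "# Toy example: three patients, longest survival time first, all events observed.\n",
    "toy_log_risk = np.array([0.0, 0.0, 0.0])\n",
    "toy_event = np.array([1, 1, 1])\n",
    "toy_loss = neg_log_partial_likelihood(toy_log_risk, toy_event)\n",
    "```\n",
    "\n",
    "If the rows were not sorted, the cumulative sum would no longer correspond to the\n",
    "risk sets, which is also why training later uses the full batch with `shuffle=False`."
   ]
  },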
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### DeepSurvK modelling\n",
    "We will use the parameters that correspond to the RGBSG dataset,\n",
    "as reported in Table 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = {'n_layers':1,\n",
    "          'n_nodes':8,\n",
    "          'activation':'selu',\n",
    "          'learning_rate':0.154,\n",
    "          'decays':5.667e-3,\n",
    "          'momentum':0.887,\n",
    "          'l2_reg':6.551,\n",
    "          'dropout':0.661,\n",
    "          'optimizer':'nadam'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dsk = deepsurvk.DeepSurvK(n_features=n_features, \n",
    "                          E=E_train,\n",
    "                          **params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = deepsurvk.negative_log_likelihood(E_train)\n",
    "dsk.compile(loss=loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "callbacks = deepsurvk.common_callbacks()\n",
    "\n",
    "epochs = 1000\n",
    "history = dsk.fit(X_train, Y_train, \n",
    "                  batch_size=n_patients_train,\n",
    "                  epochs=epochs, \n",
    "                  callbacks=callbacks,\n",
    "                  shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "deepsurvk.plot_loss(history)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# Perform predictions for test data (sanity check)\n",
    "Y_pred_test = np.exp(-dsk.predict(X_test))\n",
    "c_index_test = deepsurvk.concordance_index(Y_test, Y_pred_test, E_test)\n",
    "print(f\"c-index of testing dataset = {c_index_test}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Treatment recommendation\n",
    "The original paper has a very clear explanation of how the treatment\n",
    "recommender system works, which is as follows:\n",
    "\n",
    "> We assume each treatment $i$ to have an independent risk function \n",
    "$e^{h_i(x)}$. [...] For any patient, the [model] should be able to \n",
    "accurately predict the log-risk $h_i(x)$ of being prescribed a given\n",
    "treatment $i$. Then, based on the assumption that each individual\n",
    "has the same baseline hazard function $\\lambda_0(t)$, we\n",
    "can take the log of the hazards ratio to calculate the personal\n",
    "risk-ratio of prescribing one treatment option over\n",
    "another. We define this difference of log hazards as the\n",
    "recommender function $rec_{ij}(x)$:\n",
    "\n",
    "$rec_{ij}(x) = \\log \\left( \\frac{\\lambda_0(t) e^{h_i(x)}}{\\lambda_0(t) e^{h_j(x)}} \\right) $\n",
    "\n",
    "$rec_{ij}(x) = h_i(x) - h_j(x) $\n",
    "\n",
    "DeepSurvK provides the function `recommender_function`, which \n",
    "allows calculating $rec_{ij}(x)$ in a very easy way:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rec_ij = deepsurvk.recommender_function(dsk, X_test, 'horm_treatment')"
   ]
  },
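  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the formula concrete, here is a rough sketch of what such a computation\n",
    "involves: evaluate the log-risk once with the treatment column set to $i$ and once\n",
    "with it set to $j$, then subtract. This is not DeepSurvK's implementation.\n",
    "`toy_log_risk`, its coefficients, `toy_recommender`, and the assumption that the\n",
    "treatment is a binary 0/1 column are all made up for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def toy_log_risk(X):\n",
    "    # Hypothetical stand-in for a fitted model's log-risk h(x).\n",
    "    # Column 0 is the treatment flag; column 1 is a covariate that\n",
    "    # changes how much the treatment helps.\n",
    "    treatment, covariate = X[:, 0], X[:, 1]\n",
    "    return 0.5 * covariate - treatment * (1.0 - covariate)\n",
    "\n",
    "def toy_recommender(log_risk_fn, X, treatment_col, i=1, j=0):\n",
    "    # rec_ij(x) = h_i(x) - h_j(x): same patients, two treatment values.\n",
    "    X_i, X_j = X.copy(), X.copy()\n",
    "    X_i[:, treatment_col] = i\n",
    "    X_j[:, treatment_col] = j\n",
    "    return log_risk_fn(X_i) - log_risk_fn(X_j)\n",
    "\n",
    "# Two toy patients: (treatment received, covariate).\n",
    "X_toy = np.array([[0.0, 0.2],\n",
    "                  [1.0, 1.5]])\n",
    "rec_toy = toy_recommender(toy_log_risk, X_toy, treatment_col=0)\n",
    "```\n",
    "\n",
    "In this toy setup the first patient gets a negative value (treatment 1 lowers the\n",
    "predicted risk) and the second a positive one (treatment 0 is preferable). Note\n",
    "that the treatment each patient actually received plays no role at this stage."
   ]
  },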
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> The recommender function can be used to provide personalized treatment \n",
    "recommendations. We first pass a patient through the network once in \n",
    "treatment group $i$ and again in treatment group $j$ and take the \n",
    "difference. When a patient receives a positive recommendation $rec_{ij}(x)$,\n",
    "treatment $i$ leads to a higher risk of death than treatment $j$. Hence, \n",
    "the patient should be prescribed treatment $j$. Conversely, a negative \n",
    "recommendation indicates that treatment $i$ is more effective and leads to \n",
    "a lower risk of death than treatment $j$, and we recommend treatment $i$.\n",
    "\n",
    "DeepSurvK also provides a function to find these subsets of patients \n",
    "(recommendation and anti-recommendation):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "recommendation_idx, _ = deepsurvk.get_recs_antirecs_index(rec_ij, X_test, 'horm_treatment')"
   ]
  },
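  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, a patient falls in the *recommendation* group when the treatment\n",
    "they actually received matches the one suggested by the sign of $rec_{ij}(x)$,\n",
    "and in the *anti-recommendation* group otherwise. The sketch below illustrates\n",
    "that rule with toy values. Whether DeepSurvK applies exactly this rule\n",
    "internally is an assumption, and all variable names here are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Toy values (assumptions): rec_ij with i = treated (1), j = untreated (0).\n",
    "rec_toy = np.array([-0.8, 0.5, -0.1, 0.3])\n",
    "treatment_received = np.array([1, 1, 0, 0])\n",
    "\n",
    "# Negative rec_ij favours treatment i (1); positive favours treatment j (0).\n",
    "recommended_treatment = np.where(rec_toy < 0, 1, 0)\n",
    "\n",
    "# True where the treatment received matches the recommended one.\n",
    "toy_recommendation_idx = recommended_treatment == treatment_received\n",
    "toy_antirecommendation_idx = ~toy_recommendation_idx\n",
    "```\n",
    "\n",
    "Because the two groups are complementary, one boolean mask is enough to\n",
    "describe both."
   ]
  },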
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`get_recs_antirecs_index` also returns `antirecommendation_idx` as a second output.\n",
    "However, it is simply the logical negation of `recommendation_idx`.\n",
    "Therefore, we will discard it and work only with `recommendation_idx`.\n",
    "\n",
    "Finally, we can generate Kaplan-Meier (KM) curves for each patient group.\n",
    "To do so, first we need to invert the transformation we did previously\n",
    "on `Y_test` (to bring it back to proper time units)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_test_original = Y_test.copy(deep=True)\n",
    "Y_test_original['T'] = Y_scaler.inverse_transform(Y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DeepSurvK provides a function to quickly generate such a plot.\n",
    "Notice how this visualization closely matches Fig. 6a of\n",
    "the original paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "deepsurvk.plot_km_recs_antirecs(Y_test_original, E_test, recommendation_idx)"
   ]
  },
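  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For readers curious about what goes into each curve, the following is a minimal\n",
    "NumPy sketch of the Kaplan-Meier estimator: at every observed event time, the\n",
    "survival estimate is multiplied by one minus the fraction of at-risk patients\n",
    "who had the event. The function `km_curve` and the toy data are hypothetical\n",
    "and are not part of DeepSurvK, whose plotting function may compute the curves\n",
    "differently.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def km_curve(times, events):\n",
    "    # Kaplan-Meier estimate evaluated at each distinct event time.\n",
    "    times = np.asarray(times, dtype=float)\n",
    "    events = np.asarray(events, dtype=bool)\n",
    "    event_times = np.unique(times[events])\n",
    "    survival, s = [], 1.0\n",
    "    for t in event_times:\n",
    "        n_at_risk = np.sum(times >= t)            # still under observation at t\n",
    "        n_events = np.sum((times == t) & events)  # events occurring exactly at t\n",
    "        s *= 1.0 - n_events / n_at_risk\n",
    "        survival.append(s)\n",
    "    return event_times, np.array(survival)\n",
    "\n",
    "# Toy data: five patients, event flag 1 = death observed, 0 = censored.\n",
    "toy_times = np.array([2, 3, 3, 5, 8])\n",
    "toy_events = np.array([1, 1, 0, 1, 0])\n",
    "t_km, s_km = km_curve(toy_times, toy_events)\n",
    "```\n",
    "\n",
    "Applying a function like this separately to the patients selected by\n",
    "`recommendation_idx` and to the remaining ones would give one curve per group."
   ]
  },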
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the KM curve of the patients who were treated according\n",
    "to the model's recommendation lies above that of the patients who were\n",
    "*not*, that is, their estimated survival is higher. The log-rank test\n",
    "supports this difference, with a $p$-value of 0.0034."
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,py:percent"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
